from layer_naive import *

# Forward and backward pass over a small computational graph for buying
# apples and oranges and then applying tax. The graph is built from the
# MulLayer / AddLayer classes in layer_naive. Presumably MulLayer multiplies
# its two inputs and AddLayer adds them (TODO confirm against layer_naive),
# which would give:
#   price = (apple * apple_num + orange * orange_num) * tax

# inputs
apple = 100
apple_num = 2

orange = 150
orange_num = 3

tax = 1.1

# layer: one layer instance per operation in the graph; each instance gets
# exactly one forward call and one backward call below.
mul_apple_layer = MulLayer()
mul_orange_layer = MulLayer()
add_apple_orange_layer = AddLayer()
mul_tax_layer = MulLayer()

# forward: evaluate the graph from the inputs to the final price.
apple_price = mul_apple_layer.forward(apple, apple_num)
orange_price = mul_orange_layer.forward(orange, orange_num)
total_price = add_apple_orange_layer.forward(apple_price, orange_price)
price = mul_tax_layer.forward(total_price, tax)

# backward: start from d(price)/d(price) = 1 and walk the layers in the
# reverse order of the forward pass. Each call returns the gradients for
# that layer's two forward inputs.
# NOTE(review): the method is spelled `backword` throughout; confirm this
# matches the name actually defined in layer_naive (the conventional
# spelling is `backward`).
dprice = 1
dtotal_price, dtax = mul_tax_layer.backword(dprice)
dapple_price, dorange_price = add_apple_orange_layer.backword(dtotal_price)
dorange, dorange_num = mul_orange_layer.backword(dorange_price)
dapple, dapple_num = mul_apple_layer.backword(dapple_price)

# price, dapple_num and dorange_num are presumably derived from the float
# `tax`, so they carry floating-point error (e.g. 100 * 1.1 evaluates to
# 110.00000000000001). round() snaps them to the nearest integer. int()
# would truncate, printing one less whenever the error falls below the true
# value (e.g. 100 * 1.15 evaluates to 114.99999999999999, so int() gives 114).
print("price:", round(price))
print("dApple:", dapple)
print("dApple_num:", round(dapple_num))
print("dOrange:", dorange)
print("dOrange_num:", round(dorange_num))
print("dTax:", dtax)